{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/03-ce-scoring.ipynb) [![Open nbviewer](https://raw.githubusercontent.com/pinecone-io/examples/master/assets/nbviewer-shield.svg)](https://nbviewer.org/github/pinecone-io/examples/blob/master/learn/analytics-and-ml/model-training/gpl/03-ce-scoring.ipynb)\n",
    "\n",
    "# Cross-Encoder Scoring\n",
    "\n",
    "The final step in preparing our training data for GPL is cross-encoder scoring. So far we have generated two kinds of *(query, passage)* pairs:\n",
    "\n",
    "$$Positives = (Q, P^+)$$\n",
    "\n",
    "<center>and</center>\n",
    "\n",
    "$$Negatives = (Q, P^-)$$\n",
    "\n",
    "We pass these into a cross encoder model that is trained to predict similarity scores of *(Q, P)* pairs.\n",
    "\n",
    "First, we will load our cross encoder model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder at 0x7ff671abbe50>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import CrossEncoder\n",
    "\n",
    "model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define a generator function to read the *(query, positive, negative)* triplets from file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "def get_lines():\n",
    "    # read every line of the triplets file into memory\n",
    "    with open('data/triplets.tsv', 'r', encoding='utf-8') as fp:\n",
    "        lines = fp.read().split('\\n')\n",
    "    # loop through each line, showing progress with tqdm\n",
    "    for line in tqdm(lines):\n",
    "        # skip blank lines (e.g. a trailing newline at the end of the file)\n",
    "        if not line.strip():\n",
    "            continue\n",
    "        q, p, n = line.split('\\t')\n",
    "        # yield the query, positive, negative\n",
    "        yield q, p, n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the cross encoder to score the similarity of the positive pair and of the negative pair, then take the margin between the two scores as the label. We need the margin because we will train our bi-encoder model with Margin MSE loss, which uses this difference as its target.\n",
    "\n",
    "$$ Margin = sim(Q, P^+) - sim(Q, P^-) $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "61f52178f8f2431abd642f65a719acfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/200000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "lines = get_lines()\n",
    "label_lines = []\n",
    "\n",
    "for line in lines:\n",
    "    q, p, n = line\n",
    "    p_score = model.predict((q, p))\n",
    "    n_score = model.predict((q, n))\n",
    "    margin = p_score - n_score\n",
    "    # append pairs to label_lines with margin score\n",
    "    label_lines.append(\n",
    "        q + '\\t' + p + '\\t' + n + '\\t' + str(margin)\n",
    "    )\n",
    "\n",
    "with open(\"data/triplets_margin.tsv\", 'w', encoding='utf-8') as fp:\n",
    "    fp.write('\\n'.join(label_lines))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our *(Q, P<sup>+</sup>, P<sup>-</sup>)* triplets labeled with margin scores, we can move on to fine-tuning with Margin MSE loss.\n",
    "\n",
    "---"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7 (main, Sep 14 2022, 22:38:23) [Clang 14.0.0 (clang-1400.0.29.102)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
